import os
import cv2
import numpy as np
from faceRecognition.utils import utils
from faceRecognition.net.inception import InceptionResNetV1
from faceRecognition.net.mtcnn import mtcnn

class FaceRec:
    """Face detection (MTCNN) plus face embedding (FaceNet InceptionResNetV1).

    Image inputs to the public methods are encoded image file contents
    (bytes) that OpenCV's ``imdecode`` can read.  Detections and embeddings
    are returned as plain Python lists.
    """

    def __init__(self, model_path='faceRecognition/model_data/facenet_keras.h5',
                 threshold=None):
        """Load the MTCNN detector and the FaceNet weights.

        Args:
            model_path: Path of the FaceNet weights file passed to
                ``load_weights``.  A relative path is resolved against the
                current working directory.
            threshold: Detection thresholds for the three MTCNN stages
                (pnet, rnet, onet).  Defaults to ``[0.5, 0.6, 0.8]``.
        """
        print("正在加载mtcnn模型")  # "Loading the mtcnn model"
        self.mtcnn_model = mtcnn()
        # pnet, rnet, onet; copied so a caller's sequence is never aliased.
        self.threshold = [0.5, 0.6, 0.8] if threshold is None else list(threshold)
        print("正在加载facenet模型")  # "Loading the facenet model"
        self.facenet_model = InceptionResNetV1()
        self.facenet_model.load_weights(model_path)

    @staticmethod
    def _decode_rgb(img):
        """Decode encoded image bytes into an RGB ndarray."""
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        buffer = np.frombuffer(img, np.uint8)
        bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        # NOTE(review): imdecode returns None for undecodable input; that
        # failure currently surfaces from cvtColor as a cv2.error.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _face_box(rectangle, width, height):
        """Return integer (x1, y1, x2, y2) of a detection clipped to the image."""
        x1 = min(max(int(rectangle[0]), 0), width)
        y1 = min(max(int(rectangle[1]), 0), height)
        x2 = min(max(int(rectangle[2]), 0), width)
        y2 = min(max(int(rectangle[3]), 0), height)
        return x1, y1, x2, y2

    def get_rectangles(self, img):
        """Detect faces in an encoded image.

        Args:
            img: Encoded image file contents (bytes).

        Returns:
            The detections from ``mtcnn_model.detectFace`` as nested Python
            lists, one row per detected face.
        """
        rgb = self._decode_rgb(img)
        rectangles = self.mtcnn_model.detectFace(rgb, self.threshold)
        # np.asarray accepts both an ndarray and a list return from detectFace.
        return np.asarray(rectangles).tolist()

    def get_features(self, img):
        """Compute the FaceNet embedding of the first detected face.

        The face is cropped, aligned with its five landmarks and resized to
        160x160 before being fed to the embedding model.

        Args:
            img: Encoded image file contents (bytes).

        Returns:
            The embedding from ``utils.calc_128_vec`` as a Python list
            (presumably 128 values, per the helper's name).

        Raises:
            IndexError: If no face is detected in the image.
        """
        rgb = self._decode_rgb(img)
        height, width = rgb.shape[:2]
        rectangles = self.mtcnn_model.detectFace(rgb, self.threshold)
        if len(rectangles) == 0:
            raise IndexError("no face detected in image")
        rectangles = utils.rect2square(np.array(rectangles))
        # NOTE(review): uses the first detection in whatever order detectFace
        # returns them, not necessarily the largest face.
        rectangle = rectangles[0]
        # rect2square presumably enlarges the box to a square, which can push
        # it past the image border; negative slice indices would wrap around,
        # so clip first and express the landmarks relative to the clipped
        # origin so they stay aligned with the crop.
        x1, y1, x2, y2 = self._face_box(rectangle, width, height)
        landmark = np.reshape(rectangle[5:15], (5, 2)) - np.array([x1, y1])
        crop_img = rgb[y1:y2, x1:x2]
        crop_img, _ = utils.Alignment_1(crop_img, landmark)
        crop_img = np.expand_dims(cv2.resize(crop_img, (160, 160)), 0)
        return utils.calc_128_vec(self.facenet_model, crop_img).tolist()

    def compare_features(self, features1, features2):
        """Compare face embeddings with ``utils.face_distance``.

        Both arguments are converted to numpy arrays first.  Presumably
        ``features1`` holds one or more embeddings and ``features2`` a single
        embedding — confirm against ``utils.face_distance``.

        Returns:
            The distances as a Python list.
        """
        distances = utils.face_distance(np.array(features1), np.array(features2))
        return distances.tolist()

    def compareAll_features(self, featuresList, features):
        """Compare an embedding against a list with ``utils.face_distances``.

        Both arguments are converted to numpy arrays first.

        Returns:
            The distances as a Python list.
        """
        distances = utils.face_distances(np.array(featuresList), np.array(features))
        return distances.tolist()
if __name__ == '__main__':
    # Running the module directly only constructs FaceRec, which loads both models.
    face_rec = FaceRec()